from torch import nn

from NN.nn_module import NNM

class CNNBN(NNM):
    """Conv/batch-norm/ReLU feature extractor with a two-layer classifier head.

    The layers are assembled in :meth:`module` and stored on ``self.nn_stack``.
    """

    def module(self):
        """Build the layer stack and assign it to ``self.nn_stack``.

        Reads ``self.in_channels`` (channels fed to the first Conv2d) and
        ``self.y_shape`` (out_features of the final Linear). Both are presumably
        populated by ``NNM`` before this is called -- confirm against the base class.

        Resulting ``nn.Sequential`` order (indices determine ``state_dict`` keys):
        three stages of Conv2d(3x3, stride 1, no padding) -> BatchNorm2d -> ReLU
        with 32, 64, 64 output channels; then MaxPool2d(2), Flatten,
        Linear(10816 -> 256), ReLU, Linear(256 -> y_shape), LogSoftmax(dim=1).
        """
        # (input channels, output channels) for each conv -> BN -> ReLU stage.
        channel_plan = [(self.in_channels, 32), (32, 64), (64, 64)]

        layers = []
        for c_in, c_out in channel_plan:
            layers.extend(
                [
                    nn.Conv2d(c_in, c_out, kernel_size=3, stride=1),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(),
                ]
            )

        # 10816 == 64 * 13 * 13: the three unpadded 3x3 convs shrink each spatial
        # side by 6 and the 2x2 max-pool halves it, so this width matches a
        # 32x32 (or 33x33) input. Other input sizes would need a different value.
        flat_features = 10816
        hidden_units = 256

        layers += [
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(flat_features, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, self.y_shape),
            nn.LogSoftmax(dim=1),
        ]

        self.nn_stack = nn.Sequential(*layers)
